{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3d77a209",
   "metadata": {},
   "source": [
    "# Loading Data into PGVector\n",
    "\n",
    "The goal of this notebook is to show how to load embeddings from `unstructured` outputs into a Postgres database with the `pgvector` extension installed.\n",
    "The [Postgres documentation](https://www.postgresql.org/docs/15/tutorial-install.html) has instructions on how to install the Postgres database.\n",
    "See [the `pgvector` repo](https://github.com/pgvector/pgvector) for information on how to install `pgvector`.\n",
    "\n",
    "Postgres with `pgvector` is helpful because it combines the capabilities of a vector database with the structured information available in a traditional RDBMS. In this example, we'll show how to:\n",
    "\n",
    "- Load `unstructured` outputs into `pgvector`.\n",
    "- Conduct a similarity search conditioned on a metadata field.\n",
    "- Conduct a similarity search with a decayed score that favors more recent information."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffbe19a7",
   "metadata": {},
   "source": [
    "## Set Up the Postgres Database\n",
    "\n",
    "First, we'll get everything set up for the Postgres database. We'll use `sqlalchemy` as\n",
    "an ORM for defining the table and performing queries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8a538b14",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sqlalchemy import (\n",
    "    create_engine,\n",
    "    ARRAY,\n",
    "    Column,\n",
    "    Integer,\n",
    "    String,\n",
    "    Float,\n",
    "    DateTime,\n",
    "    func,\n",
    "    text,\n",
    ")\n",
    "from pgvector.sqlalchemy import Vector\n",
    "from sqlalchemy.orm import declarative_base, sessionmaker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "91893826",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensionality of the embedding vectors (1536 for OpenAI's text-embedding-ada-002).\n",
    "ADA_TOKEN_COUNT = 1536"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6614726b",
   "metadata": {},
   "outputs": [],
   "source": [
    "connection_string = \"postgresql://localhost:5432/postgres\"\n",
    "engine = create_engine(connection_string)\n",
    "\n",
    "Base = declarative_base()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bb2ffda8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Element(Base):\n",
    "    __tablename__ = \"unstructured_elements\"\n",
    "\n",
    "    id = Column(Integer, primary_key=True)\n",
    "    embedding = Column(Vector(ADA_TOKEN_COUNT))\n",
    "    text = Column(String)\n",
    "    category = Column(String)\n",
    "    filename = Column(String)\n",
    "    date = Column(DateTime)\n",
    "    sent_to = Column(ARRAY(String))\n",
    "    sent_from = Column(ARRAY(String))\n",
    "    subject = Column(String)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e9130393",
   "metadata": {},
   "outputs": [],
   "source": [
    "Base.metadata.create_all(engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d59aa418",
   "metadata": {},
   "outputs": [],
   "source": [
    "Session = sessionmaker(bind=engine)\n",
    "session = Session()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24fcb4fa",
   "metadata": {},
   "source": [
    "## Preprocess Documents with Unstructured\n",
    "\n",
    "Next, we'll preprocess data (in this case emails) using the `partition_email` function from `unstructured`. We'll also use the `OpenAIEmbeddings` class from `langchain` to create embeddings from the text. The embeddings will be used for similarity search after we've loaded the documents into the database."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b08244dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import os\n",
    "\n",
    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
    "from unstructured.partition.email import partition_email"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "de97a526",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXAMPLE_DOCS_DIRECTORY = \"../../example-docs\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5a18fd82",
   "metadata": {},
   "outputs": [],
   "source": [
    "elements = []\n",
    "for f in os.listdir(EXAMPLE_DOCS_DIRECTORY):\n",
    "    if not f.endswith(\".eml\"):\n",
    "        continue\n",
    "\n",
    "    filename = os.path.join(EXAMPLE_DOCS_DIRECTORY, f)\n",
    "    elements.extend(partition_email(filename=filename))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b69915f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_function = OpenAIEmbeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2e6537d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "for element in elements:\n",
    "    element.embedding = embedding_function.embed_query(element.text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af262aa9",
   "metadata": {},
   "source": [
    "## Load the Documents into Postgres\n",
    "\n",
    "Now that we've preprocessed the documents, we're ready to load the results into the database. We'll do this by creating objects with `sqlalchemy` using the schema we defined earlier and then running an insert command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a47c99d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "items_to_add = []\n",
    "for element in elements:\n",
    "    items_to_add.append(\n",
    "        Element(\n",
    "            text=element.text,\n",
    "            category=element.category,\n",
    "            embedding=element.embedding,\n",
    "            filename=element.metadata.filename,\n",
    "            date=element.metadata.get_date(),\n",
    "            sent_to=element.metadata.sent_to,\n",
    "            sent_from=element.metadata.sent_from,\n",
    "            subject=element.metadata.subject,\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5d6bbf43",
   "metadata": {},
   "outputs": [],
   "source": [
    "session.add_all(items_to_add)\n",
    "session.commit()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d013b64",
   "metadata": {},
   "source": [
    "## Query the Database\n",
    "\n",
    "Finally, we're ready to query the database. The results from similarity search can be used for retrieval augmented generation, as described in the [`langchain` docs](https://docs.langchain.com/docs/use-cases/qa-docs). First, we'll run a query conditioned on metadata: we'll search for similar items, but restrict the results to narrative text elements. You can also perform this operation using the [`pgvector` vectorstore](https://github.com/hwchase17/langchain/blob/master/langchain/vectorstores/pgvector.py) in `langchain`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7ba10d65",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector = embedding_function.embed_query(\"email\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "25ed06a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 This is a test email to use for unit tests.\n",
      "13 This is a test email to use for unit tests.\n",
      "5 This is a test email to use for unit tests.\n",
      "9 The unstructured logo is attached to this email.\n",
      "19 It includes:\n"
     ]
    }
   ],
   "source": [
    "query = (\n",
    "    session.query(Element)\n",
    "    .filter(Element.category == \"NarrativeText\")\n",
    "    .order_by(Element.embedding.l2_distance(vector))\n",
    "    .limit(5)\n",
    ")\n",
    "\n",
    "for element in query:\n",
    "    print(element.id, element.text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10b0b984",
   "metadata": {},
   "source": [
    "Next, we'll run a similarity search, but add a decay function that biases the results toward the most recent documents. This can be helpful if you want to run retrieval augmented generation, but are concerned about passing outdated information into the LLM. In this case, we take the negative exponential of the product of the decay rate, the document's age in days, and the L2 distance, then sort by that score in descending order. The score shrinks as a document gets older or less similar to the query."
   ]
  },
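  {
   "cell_type": "markdown",
   "id": "decay-score-sketch",
   "metadata": {},
   "source": [
    "As a rough sketch of what the next query computes (the symbols are our own shorthand, not names from `pgvector` or `sqlalchemy`), let $d$ be the L2 distance between the query vector and an element's embedding, $\\Delta t$ the element's age in whole days, $\\lambda$ the `decay_rate`, and $s$ the resulting `decay_score`:\n",
    "\n",
    "$$s = \\exp(-\\lambda \\, \\Delta t \\, d)$$\n",
    "\n",
    "For illustration, with $\\lambda = 0.10$ and $d = 0.5$, a 30-day-old element scores $\\exp(-1.5) \\approx 0.223$, while a 60-day-old element scores $\\exp(-3) \\approx 0.050$. One consequence of this form is that an element dated today has $\\Delta t = 0$ and scores 1 regardless of its distance, so you may want a different combination of age and distance if same-day documents are common in your data."
   ]
  },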
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "532cb832",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector = embedding_function.embed_query(\"violets\")\n",
    "decay_rate = 0.10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2ebff5da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0050977532596662945 -  Violets are blue\n",
      "0.001773595479626926 -  Violets are blue\n",
      "0.001773595479626926 -  Violets are blue\n",
      "0.0011421532895244265 -  Roses are red\n",
      "0.00029501066142995373 -  Roses are red\n"
     ]
    }
   ],
   "source": [
    "query = (\n",
    "    session.query(\n",
    "        Element,\n",
    "        Element.text,\n",
    "        func.exp(\n",
    "            text(f\"-{decay_rate} * EXTRACT(DAY FROM (NOW() - date))\")\n",
    "            * Element.embedding.l2_distance(vector)\n",
    "        ).label(\"decay_score\"),\n",
    "    )\n",
    "    .order_by(text(\"decay_score DESC\"))\n",
    "    .limit(5)\n",
    ")\n",
    "\n",
    "for element in query:\n",
    "    print(f\"{element.decay_score} -  {element.text}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
